# -*- coding: utf-8 -*-
"""
Created on Sun Oct 24 18:21:58 2021

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.autograd import Variable
import time
import pandas as pd

import networkx as nx
import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from datetime import datetime
from torch.nn.utils import clip_grad_norm_
import copy
import random
import pickle
import os

from models import PhyCRNet,MaskedMSELoss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs
import argparse


#%%test dataset
def test_model(epoch, writer):
    """Evaluate the global ``model`` on the test split and log the losses.

    Iterates over ``test_input`` in mini-batches of ``args['batch_size']``.
    For each batch it computes ``weightMSELoss`` against
    ``test_label1``/``test_label1_coe`` and ``MaskedMSELoss`` against
    ``test_label2``/``test_label2_mask``. Each loss is scaled by
    ``dev_list[0]**2`` and the per-batch values are averaged.

    Args:
        epoch: Step index used for the TensorBoard scalars.
        writer: TensorBoard writer; receives ``test/loss1``, ``test/loss2``
            and ``test/all_loss``.

    Returns:
        Mean of the two averaged, scaled losses, or ``np.inf`` when the
        test split is empty (nothing is logged in that case).
    """
    model.eval()
    batch_size = args['batch_size']
    loss_fn1 = weightMSELoss()
    loss_fn2 = MaskedMSELoss()
    # Scale applied to both losses; presumably undoes the target
    # normalisation (dev_list[0] looks like a std) -- confirm in data_utils.
    scale = dev_list[0] ** 2
    losses_l1 = []

    # Pure evaluation: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i in range(0, test_input.shape[0], batch_size):
            batch = slice(i, i + batch_size)
            inputs = test_input[batch].to(device=device)
            label1 = test_label1[batch].to(device=device)
            label1_coe = test_label1_coe[batch].to(device=device)
            label2 = test_label2[batch].to(device=device)
            label2_mask = test_label2_mask[batch].to(device=device)
            outputs = model(inputs)

            loss1 = loss_fn1(outputs, label1, label1_coe)
            loss2 = loss_fn2(outputs, label2, label2_mask)
            losses_l1.append([loss1.item() * scale, loss2.item() * scale])

    if not losses_l1:
        return np.inf

    losses_l1 = np.array(losses_l1).mean(axis=0)
    writer.add_scalar('test/loss1', losses_l1[0], epoch)
    writer.add_scalar('test/loss2', losses_l1[1], epoch)
    writer.add_scalar('test/all_loss', losses_l1.mean(), epoch)

    return losses_l1.mean()



#%%validation dataset
def val_model(epoch, writer):
    """Evaluate the global ``model`` on the validation split and log the losses.

    Iterates over ``val_input`` in mini-batches of ``args['batch_size']``.
    For each batch it computes ``weightMSELoss`` against
    ``val_label1``/``val_label1_coe`` and ``MaskedMSELoss`` against
    ``val_label2``/``val_label2_mask``. Each loss is scaled by
    ``dev_list[0]**2`` and the per-batch values are averaged.

    Args:
        epoch: Step index used for the TensorBoard scalars.
        writer: TensorBoard writer; receives ``val/loss1``, ``val/loss2``
            and ``val/all_loss``.

    Returns:
        Mean of the two averaged, scaled losses, or ``np.inf`` when the
        validation split is empty (nothing is logged in that case).
    """
    model.eval()
    batch_size = args['batch_size']
    loss_fn1 = weightMSELoss()
    loss_fn2 = MaskedMSELoss()
    # Scale applied to both losses; presumably undoes the target
    # normalisation (dev_list[0] looks like a std) -- confirm in data_utils.
    scale = dev_list[0] ** 2
    losses_l1 = []

    # Pure evaluation: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i in range(0, val_input.shape[0], batch_size):
            batch = slice(i, i + batch_size)
            inputs = val_input[batch].to(device=device)
            label1 = val_label1[batch].to(device=device)
            label1_coe = val_label1_coe[batch].to(device=device)
            label2 = val_label2[batch].to(device=device)
            label2_mask = val_label2_mask[batch].to(device=device)
            outputs = model(inputs)

            loss1 = loss_fn1(outputs, label1, label1_coe)
            loss2 = loss_fn2(outputs, label2, label2_mask)
            losses_l1.append([loss1.item() * scale, loss2.item() * scale])

    if not losses_l1:
        return np.inf

    losses_l1 = np.array(losses_l1).mean(axis=0)
    writer.add_scalar('val/loss1', losses_l1[0], epoch)
    writer.add_scalar('val/loss2', losses_l1[1], epoch)
    writer.add_scalar('val/all_loss', losses_l1.mean(), epoch)

    return losses_l1.mean()



#%%

def train(epoch, writer, model, optimizer, lr_scheduler, args):
    """Run one training epoch over the global training split.

    Batches of ``args['batch_size']`` are taken from ``train_input`` in
    order, with no shuffling. Both targets are perturbed with uniform
    multiplicative noise of relative amplitude ``args['noise']`` percent.
    The objective is ``weightMSELoss + MaskedMSELoss``. The scheduler is
    stepped once per epoch. Every 10th epoch the validation and test
    splits are also evaluated and a summary line is printed.

    Args:
        epoch: Zero-based epoch index, also the TensorBoard step.
        writer: TensorBoard writer for ``learning_rate`` and ``train/*``.
        model: Module being trained.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Scheduler stepped once at the end of the epoch.
        args: Hyper-parameters. Reads ``batch_size`` and ``noise``, plus the
            optional ``clip_grad_norm`` (max gradient norm; unset or None
            disables clipping).

    Returns:
        ``(val_loss, train_loss)``. ``val_loss`` is the ``val_model`` result
        on epochs divisible by 10 and ``np.inf`` otherwise. ``train_loss``
        is the mean of the two scaled training losses.
    """
    t = time.time()
    batch_size = args['batch_size']
    clip_norm = args.get('clip_grad_norm')
    loss_fn1 = weightMSELoss()
    loss_fn2 = MaskedMSELoss()
    # Same loss scaling as the evaluation functions (dev_list[0]**2).
    scale = dev_list[0] ** 2
    train_mse = []
    # Record the rate actually used during this epoch, before the
    # scheduler advances it.
    current_lr = optimizer.param_groups[0]['lr']
    model.train()

    for i in range(0, train_input.shape[0], batch_size):
        batch = slice(i, i + batch_size)
        inputs = train_input[batch].to(device=device)
        label1 = train_label1[batch].to(device=device)
        label1_coe = train_label1_coe[batch].to(device=device)
        label2 = train_label2[batch].to(device=device)
        label2_mask = train_label2_mask[batch].to(device=device)
        outputs = model(inputs)

        # Uniform noise in [-noise%, +noise%] of each target value.
        label1 = label1 + (torch.rand_like(label1) * 2 - 1) * args['noise'] * label1 / 100
        label2 = label2 + (torch.rand_like(label2) * 2 - 1) * args['noise'] * label2 / 100
        loss1 = loss_fn1(outputs, label1, label1_coe)
        loss2 = loss_fn2(outputs, label2, label2_mask)
        loss_train = loss1 + 1.0 * loss2
        train_mse.append([loss1.item() * scale, loss2.item() * scale])

        optimizer.zero_grad()
        loss_train.backward()
        if clip_norm is not None:
            clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
    lr_scheduler.step()

    train_mse = np.array(train_mse).mean(axis=0)
    train_loss = train_mse.mean()
    writer.add_scalar('learning_rate', current_lr, epoch)
    writer.add_scalar('train/loss1', train_mse[0], epoch)
    writer.add_scalar('train/loss2', train_mse[1], epoch)
    writer.add_scalar('train/all_loss', train_loss, epoch)

    val_loss = np.inf
    if epoch % 10 == 0:
        val_loss = val_model(epoch, writer)
        test_loss = test_model(epoch, writer)
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(train_loss),
              'mseloss_val: {:.4f}'.format(val_loss),
              'mseloss_test: {:.4f}'.format(test_loss),
              'time: {:.4f}s'.format(time.time() - t))

    return (val_loss, train_loss)

# In[192]:


def _save_reload_and_plot(model, args, path, state, tag, *plot_args):
    """Save ``state`` and ``args`` under ``path``, load ``state`` into ``model`` and plot.

    Writes ``{path}/{tag}.pkl`` (the state dict) and ``{path}/args.pkl``.
    It then reloads the checkpoint into ``model`` and calls ``plot_conc``,
    forwarding any extra positional ``plot_args``.
    """
    model_path = f'{path}/{tag}.pkl'
    torch.save(state, model_path)
    with open(f'{path}/args.pkl', 'wb') as f:
        pickle.dump(args, f)
    model.load_state_dict(torch.load(model_path))
    plot_conc(args, model, path, device, *plot_args)


def run_exp(model, optimizer, lr_scheduler, args, i):
    """Train ``model`` for up to ``args['epochs'][i]`` epochs with early stopping.

    Each epoch calls ``train``. A validation loss is only available every
    10th epoch (``np.inf`` otherwise). Once the training loss is below
    ``args.get('train_loss_gate', 0.2)``, each epoch without a new best
    validation loss increments a patience counter. Training stops when the
    counter reaches ``args['patience']``.

    The best state dict (or the final one if no best was recorded) is
    written to ``<loot_path>log_conc/<timestamp>/<epoch>.pkl`` with
    ``args.pkl``. It is then reloaded into ``model`` and plotted with
    ``plot_conc``. If the training loss turns NaN, the same save/plot
    happens immediately and the function returns early. The TensorBoard
    writer is closed on every exit path.

    Args:
        model: Module to train (modified in place).
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Scheduler passed through to ``train``.
        args: Hyper-parameter dict (also pickled next to the checkpoint).
        i: Search-iteration index selecting the epoch budget in ``args['epochs']``.

    Returns:
        ``(best, model)``: the best validation loss seen (``np.inf`` if none)
        and the model with the saved weights loaded.
    """
    bad_counter = 0
    best = np.inf
    best_epoch = 0
    best_model = None
    epoch = 0  # defined even when zero epochs are requested
    # Patience is only counted once the training loss is below this gate.
    train_loss_gate = args.get('train_loss_gate', 0.2)

    path = loot_path + 'log_conc/' + datetime.now().strftime('%Y%m%d-%H%M%S')
    writer = SummaryWriter(path)
    try:
        for epoch in range(args['epochs'][i]):
            val_loss, train_loss = train(epoch, writer, model, optimizer, lr_scheduler, args)

            if np.isnan(train_loss):
                # Diverged: keep the best checkpoint seen so far, or the
                # current (diverged) weights if none was ever recorded.
                if best_model is not None:
                    _save_reload_and_plot(model, args, path, best_model, best_epoch)
                else:
                    _save_reload_and_plot(model, args, path, model.state_dict(), epoch)
                return best, model

            if train_loss < train_loss_gate:
                if val_loss < best:
                    best = val_loss
                    best_epoch = epoch
                    # Snapshot only on improvement; training keeps mutating the model.
                    best_model = copy.deepcopy(model.state_dict())
                    bad_counter = 0
                else:
                    bad_counter += 1
            else:
                bad_counter = 0

            if bad_counter >= args['patience']:
                print('Oops, early stopping!')
                break
    finally:
        writer.close()

    if best_model is not None:
        _save_reload_and_plot(model, args, path, best_model, best_epoch, 'sig')
    else:
        _save_reload_and_plot(model, args, path, model.state_dict(), epoch, 'sig')
    return best, model
    
#%%Train conclstm

# Read hyper-parameter overrides from the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--exp", default='exp13', type=str,
                    help="experiment name; outputs are written under ./exp/<exp>/")
parser.add_argument('--buffer', type=int, default=20,
                    help="'buffer' value passed to PrepareConcDataset")
parser.add_argument('--end_lr', type=float, default=1e-6,
                    help="end learning rate passed to PhyCRNet")
parser.add_argument('--peak_lr', type=float, default=1e-3,
                    help="peak learning rate passed to PhyCRNet")
parser.add_argument('--batch_size', type=int, default=10,
                    help="mini-batch size for training and evaluation")
parser.add_argument('--power', type=int, default=1,
                    help="'power' value passed to PhyCRNet")
parser.add_argument('--dropout', type=float, default=0.01,
                    help="dropout rate passed to PhyCRNet")
args_input = parser.parse_args()


seed = 1000


def set_seed(seed):
    """Seed every RNG used by this script so runs are reproducible.

    Seeds Python's ``random`` (used by the hyper-parameter sampling), NumPy
    and PyTorch. When CUDA is available it also seeds all CUDA devices and
    switches cuDNN to deterministic, non-benchmarking mode.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses; hash randomisation of the running
    # interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


set_seed(seed)


train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
val_input,val_label1,val_label1_coe,val_label2,val_label2_mask,\
test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
mmin_list,dev_list,conc_name,edge_list,time_scaler=PrepareConcDataset()

# Unpack the 5-D shape of train_input; input_channel feeds args['input_dim'].
batch_size, time_len, input_channel, h, w = train_input.size()
args = {
    'epochs': [500, 1000],           # epoch budget per search iteration
    'input_dim': input_channel,
    'output_dim': train_label2.shape[2],
    'patience': 100,
    'Clamp_A': True,
    'MAX_EVALS': 2,                  # number of random-search iterations
}
loot_path = './exp/' + args_input.exp + '/'
args['cuda'] = torch.cuda.is_available()

# Use the GPU when available and fall back to the CPU otherwise; the rest
# of the script needs ``device`` unconditionally.
device = torch.device('cuda' if args['cuda'] else 'cpu')

#%%
# Random-search space. Each value is a list of candidates, and one entry per
# key is drawn on every search iteration. Single-element lists pin a value;
# entries taken from ``args_input`` come from the command line.
param_grid = dict(
    hidden_dim=[[4, 8, 1]],
    input_kernel_size=[3],
    input_stride=[1],
    padding_mode=['reflect'],
    num_layers=[3],
    weight_decay=[1e-4],
    dropout=[args_input.dropout],
    warmup_updates=[10],
    tot_updates=[800],
    peak_lr=[args_input.peak_lr],
    end_lr=[args_input.end_lr],
    noise=[5],
    time_window=[1],
    batch_size=[args_input.batch_size],
    res=[30],
    buffer=[args_input.buffer],
    pred_day=[1],
    gap=[1],
    test_day=[60],
    val_ratio=[0.01],
    power=[args_input.power],
)

from utils import read_pkl,write_pkl
# NOTE(review): this replaces the ``args`` dict assembled earlier in the
# script with a pickled one (presumably an ``args.pkl`` written by a previous
# ``run_exp``). Keys such as 'epochs', 'patience', 'MAX_EVALS', 'input_dim'
# and 'output_dim' therefore come from that file. Confirm this is intended.
args=read_pkl('exp/20220708-104119/args.pkl')

# Placeholders; the loop below does not update them.
best_score = np.inf
best_count = -1
best_model = None
best_hyperparams = None

for i in range(args['MAX_EVALS']):
    print("-"*10,"search_model",i,"/",args['MAX_EVALS'],"-"*10)
    # Draw one candidate per key from param_grid and overlay it on args.
    hyperparameters = {k: random.sample(v, 1)[0] for k, v in param_grid.items()}
    args.update(hyperparameters)

    # Rebuild the splits, since several sampled hyper-parameters shape the dataset.
    (train_input, train_label1, train_label1_coe, train_label2, train_label2_mask,
     val_input, val_label1, val_label1_coe, val_label2, val_label2_mask,
     test_input, test_label1, test_label1_coe, test_label2, test_label2_mask,
     mmin_list, dev_list, conc_name, edge_list, time_scaler) = PrepareConcDataset(
        args['time_window'], args['res'],
        args['pred_day'], args['buffer'],
        args['gap'], args['test_day'],
        args['val_ratio'])

    # Rebuild the model for this configuration.
    model = PhyCRNet(args['input_dim'],
                     args['hidden_dim'],
                     args['output_dim'],
                     args['input_kernel_size'],
                     args['input_stride'],
                     args['padding_mode'],
                     args['num_layers'],
                     args['dropout'],
                     args['warmup_updates'],
                     args['tot_updates'],
                     args['peak_lr'],
                     args['end_lr'],
                     args['power'],
                     args['weight_decay'])
    # Move to the target device before the optimizer is constructed.
    model = model.to(device=device)

    # Xavier-uniform for matrices/kernels. 1-D parameters (e.g. biases) are
    # drawn from U(0, 1), the nn.init.uniform_ default range.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.uniform_(p)

    optimizer, lr_scheduler = model.configure_optimizers()

    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('*'*60)
    print(f'{loot_path} model {i} has {num_parameters} learnable parameters')
    print('*'*60)
    run_exp(model, optimizer, lr_scheduler, args, i)

#%%
import numpy as np
import matplotlib
import matplotlib as mpl
cmap = matplotlib.cm.RdYlBu
#mpl.use('qtAgg')  #('qtagg')
import matplotlib.pyplot as plt

#%%
import pickle
# Keep the first 47 wells. Rows 0-19 are plotted as production wells and the
# remaining rows as injection wells. Column 0 is drawn as the text label;
# columns 1 and 2 are the x/y positions.
with open('../cnn_data/well_index.pkl', 'rb') as f:
    well_index = pickle.load(f)[:47]
n_production = 20

fig, ax = plt.subplots(figsize=(9, 9))

ticks = [0, 60, 120, 180, 240, 300]
ax.set_xlim(-30, 330)
ax.set_ylim(-30, 330)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.grid(True)
ax.set_aspect('equal', adjustable='box')

# Coordinates are shifted by (-70, -100) for markers. Labels use
# (-68, -95) so they sit slightly beside their marker.
# NOTE: labels are attached but no legend is drawn; call ax.legend() to show one.
ax.scatter(well_index[:n_production, 1] - 70, well_index[:n_production, 2] - 100,
           marker='o', label='production well')
ax.scatter(well_index[n_production:, 1] - 70, well_index[n_production:, 2] - 100,
           marker='*', label='injection well')
for row in well_index:
    ax.text(row[1] - 68, row[2] - 95, row[0], fontsize=12, ha='center', va='center')

plt.savefig('well_index.png')
plt.savefig('well_index.eps', format='eps')
#%%
# The plotting below indexes NumPy arrays; convert torch tensors if needed.
if not isinstance(train_input, np.ndarray):
    train_input = train_input.numpy()
if not isinstance(train_label2, np.ndarray):
    train_label2 = train_label2.numpy()

var = ['time', 'flux', 'acid']
png_dir = '111/'
import os
os.makedirs(png_dir, exist_ok=True)

# Every 30th training sample (below 260), at time step 0. The bound is capped
# at the number of available samples so short datasets do not IndexError.
sample_idx = list(range(0, min(260, train_input.shape[0]), 30))
# train_input is unpacked as (sample, time, channel, h, w) at the top of the
# script. After selecting time step 0 the axes are (sample, channel, h, w),
# so reducing over (0, 2, 3) gives one shared colour range per channel.
snapshots = train_input[sample_idx, 0]
vmax_arr = snapshots.max(axis=(0, 2, 3))
vmin_arr = snapshots.min(axis=(0, 2, 3))

for i in sample_idx:
    for j, name in enumerate(var):
        harvest = train_input[i, 0, j]
        fig, ax = plt.subplots()
        ax.imshow(harvest, vmin=vmin_arr[j], vmax=vmax_arr[j])

        ax.set_xticks([])
        ax.set_yticks([])
        for side in ('left', 'right', 'bottom', 'top'):
            ax.spines[side].set_visible(False)

        ax.set_title(name)
        fig.tight_layout()
        fig.savefig(f'{png_dir}harvest_{name}_{i}')
        fig.savefig(f'{png_dir}harvest_{name}_{i}.eps', format='eps')
        plt.show()
        # Release the figure; otherwise every iteration keeps one alive.
        plt.close(fig)

#%%
import os
os.makedirs('111/', exist_ok=True)

for i in range(15, 23):
    # Take a copy. np.array accepts both an ndarray and a CPU torch tensor,
    # and the in-place edit below must not leak back into train_label2.
    harvest = np.array(train_label2[i, 0, 0], copy=True)

    # For display only: cells equal to -10.0 (presumably a fill value for
    # masked cells -- confirm in data_utils) are replaced by random values
    # in [6, 7).
    fill = harvest == -10.00
    harvest[fill] += np.random.rand(np.count_nonzero(fill)) + 6

    fig, ax = plt.subplots()
    ax.imshow(harvest)

    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('left', 'right', 'bottom', 'top'):
        ax.spines[side].set_visible(False)

    ax.set_title('uran')
    fig.tight_layout()
    fig.savefig(f'111/harvest_uran_{i}')
    fig.savefig(f'111/harvest_uran_{i}.eps', format='eps')
    plt.show()
    # Release the figure so figures do not accumulate across iterations.
    plt.close(fig)

#%%
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
#mpl.use('qtAgg')  #('qtagg')

# Streamlines of the outward radial field (U, V) = (X, Y) on a 100x100 grid
# spanning [-10, 10] in both directions.
x = np.linspace(-10, 10, 100)
y = np.linspace(-10, 10, 100)
X, Y = np.meshgrid(x, y)
U, V = X, Y

plt.streamplot(X, Y, U, V)

plt.title('Streamplot')
plt.xlabel('X')
plt.ylabel('Y')
plt.savefig('streamplot.eps',format='eps')

plt.show()

#%%
# Total element count of the parameters of the most recently built ``model``
# whose name contains 'weight' (biases are not counted). The printed label
# reads "total number of weights".
total_params = sum(param.numel()
                   for name, param in model.named_parameters()
                   if 'weight' in name)
print('总权重数量：', total_params)